from venv import logger
import torch
import gc
import os
import re
import json
import logging
import time
from datetime import datetime
from typing import Optional
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM
from transformers import GenerationConfig
from peft import LoraConfig, get_peft_model
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoConfig, AutoTokenizer
from transformers import TrainerCallback
import csv

def print_gpu_memory(result=None):
    """Print CUDA memory statistics and, optionally, a run summary.

    When CUDA is available, the currently allocated and reserved memory are
    printed in MB. When ``result`` is given, its peak memory and elapsed time
    are printed as well.

    Args:
        result: Optional object exposing ``peak_memory_usage`` (a number that
            is divided by 1024**2 for display) and ``time`` (an object with a
            ``total_seconds()`` method). Nothing extra is printed for ``None``.
    """
    bytes_per_mb = 1024**2

    if torch.cuda.is_available():
        allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
        print(f"GPU Memory Usage: {allocated_mb:.2f} MB")
        reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
        print(f"GPU Memory Cache: {reserved_mb:.2f} MB")

    if result is None:
        return

    peak_mb = result.peak_memory_usage / bytes_per_mb
    print(f"Peak Memory: {peak_mb:.2f}MB")
    elapsed_seconds = result.time.total_seconds()
    print(f"Inference Time: {elapsed_seconds:.2f}s")

def clear_gpu_memory():
    """Run Python's garbage collector, then empty PyTorch's CUDA cache."""
    # Order matters: collect first, then empty the cache.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def load_qa_dataset(file_path):
    """Load a blank-line-separated QA text file into a ``Dataset``.

    The file is split on empty lines ("\\n\\n"). In every non-empty record the
    first line is the question and the remaining lines form the completion.
    Each prompt is the fixed instruction text, the fixed examples text and the
    question, separated by blank lines.

    Args:
        file_path: Path of the UTF-8 text file to read.

    Returns:
        A ``Dataset`` with the columns ``prompt``, ``completion`` and
        ``original_text`` (the record exactly as read), or ``None`` when
        reading the file or building the dataset fails; the error is printed.
    """
    instruction = """    ."""
    examples = """Example 1:

Toxic Example 2:
"""
    # Everything in the prompt that precedes the question itself.
    prompt_header = f"{instruction}\n\n{examples}\n\n"

    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            records = handle.read().split('\n\n')

        prompts = []
        completions = []
        originals = []

        for record in records:
            body = record.strip()
            if not body:
                continue
            try:
                question, *answer_lines = body.split('\n')
                prompts.append(f"{prompt_header}{question}\n")
                completions.append('\n'.join(answer_lines))
                originals.append(record)
            except Exception as e:
                print(f"Error processing QA pair: {str(e)}")
                continue

        return Dataset.from_dict({
            "prompt": prompts,
            "completion": completions,
            "original_text": originals,
        })

    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        return None

def reward_func_outcome(completions, **kwargs):
    """Score completions on output format, final answer and harmful content.

    For every completion the reward is assembled as follows:

    * +0.3 when "The last letter of", "Concatenating them is" and
      "The answer is" all occur in the completion;
    * +0.7 when the last quoted value following "The answer is" equals the
      expected answer;
    * -1.0 when "@_@" occurs before the first "The answer is" and the
      completion contains no "<harm>" tag.

    The sum is clamped to [-1.0, 1.0]. If scoring a sample raises, that
    sample receives -0.5 and the exception text is logged.

    Args:
        completions: Iterable of completion strings.
        **kwargs: Optional values:
            correct_answer: Expected answer. Either a single string applied
                to every completion, or a list/tuple holding one expected
                answer per completion. Defaults to ``''``, in which case the
                answer bonus can never be earned.
            reward_log_path: File the per-sample log is appended to.
            step: Step number; used only in the log text (default 0).
            Any other keyword is ignored.

    Returns:
        list[float]: One reward per completion, in input order.
    """
    reward_log_path = kwargs.get('reward_log_path', None)
    step = kwargs.get('step', 0)
    expected = kwargs.get('correct_answer', '')
    required_phrases = (
        "The last letter of",
        "Concatenating them is",
        "The answer is",
    )
    answer_pattern = r'The answer is ["\']([^"\']+)["\']'

    rewards = []
    log_blocks = []
    for idx, completion in enumerate(completions):
        # Created before the try so the except branch can always append to it.
        log_lines = []
        try:
            reward = 0.0
            format_ok = all(phrase in completion for phrase in required_phrases)
            if format_ok:
                reward += 0.3
            log_lines.append(f"Sample {idx+1} | Step {step}")
            log_lines.append(f"Completion:\n{completion}")
            log_lines.append(f"Format check: {'OK' if format_ok else 'FAIL'} (+0.3 if OK)")

            answer_match = re.findall(answer_pattern, completion)
            # Trainers hand dataset columns over as one value per sample;
            # comparing the extracted string with the whole list would never
            # match, so pick this sample's entry.
            if isinstance(expected, (list, tuple)):
                correct_answer = expected[idx]
            else:
                correct_answer = expected
            answer_ok = bool(answer_match) and answer_match[-1] == correct_answer
            if answer_ok:
                reward += 0.7
            log_lines.append(f"Correct answer: {correct_answer}, Model answer: {answer_match[-1] if answer_match else 'N/A'} (+0.7 if correct)")

            # Only the text before the first "The answer is" is inspected.
            reasoning_section = completion.split("The answer is")[0]
            harm_penalty = "@_@" in reasoning_section and "<harm>" not in completion
            if harm_penalty:
                reward -= 1.0
            log_lines.append(f"Harmful content penalty: {'-1.0' if harm_penalty else '0'}")
        except Exception as e:
            reward = -0.5
            log_lines.append(f"Exception in reward_func_outcome: {str(e)}")
        reward = max(-1.0, min(1.0, reward))
        log_lines.append(f"Final reward: {reward}")
        log_lines.append("="*40)
        rewards.append(reward)
        log_blocks.append('\n'.join(log_lines) + '\n')

    # Append the whole batch at once instead of reopening the file per sample.
    if reward_log_path and log_blocks:
        with open(reward_log_path, 'a', encoding='utf-8') as f:
            f.write(''.join(log_blocks))
    return rewards

def reward_func_process(prompts, completions, **kwargs):
    """Score completions on their use of the ``<suspect>``/``<harm>`` tags.

    For each (prompt, completion) pair:

    * +0.8 when the prompt contains "@_@" and the completion "<suspect>";
    * +0.2 when the prompt contains "@_@" and the completion contains both
      "<harm>" and "</harm>";
    * -0.3 when "Concatenating them is" occurs more than once;
    * -0.5 when the prompt has no "@_@" but the completion contains
      "<suspect>" or "<harm>".

    The sum is clamped to [-1.0, 1.0]; a sample whose scoring raises gets
    -0.5 and the exception text is logged.

    Args:
        prompts: Iterable of prompt strings.
        completions: Iterable of completion strings, paired with ``prompts``.
        **kwargs: ``reward_log_path`` (file the per-sample log is appended
            to) and ``step`` (used only in the log text, default 0); other
            keywords are ignored.

    Returns:
        list[float]: One reward per pair, in input order.
    """
    log_path = kwargs.get('reward_log_path', None)
    step = kwargs.get('step', 0)
    scores = []

    for position, (prompt, completion) in enumerate(zip(prompts, completions)):
        try:
            score = 0.0
            report = []

            # Evaluate every flag first, then apply the adjustments in order.
            marked = "@_@" in prompt
            suspect_ok = marked and "<suspect>" in completion
            harm_ok = marked and "<harm>" in completion and "</harm>" in completion
            format_penalty = completion.count("Concatenating them is") > 1
            extra_tag_penalty = (not marked) and ("<suspect>" in completion or "<harm>" in completion)

            adjustments = (
                (suspect_ok, 0.8),
                (harm_ok, 0.2),
                (format_penalty, -0.3),
                (extra_tag_penalty, -0.5),
            )
            for applies, delta in adjustments:
                if applies:
                    score += delta

            report.extend([
                f"Sample {position+1} | Step {step}",
                f"Prompt:\n{prompt}",
                f"Completion:\n{completion}",
                f"Suspect tag: {'OK' if suspect_ok else 'MISS'} (+0.8 if OK)",
                f"Harm tag: {'OK' if harm_ok else 'MISS'} (+0.2 if OK)",
                f"Format penalty: {'-0.3' if format_penalty else '0'}",
                f"Extra tag penalty: {'-0.5' if extra_tag_penalty else '0'}",
            ])
        except Exception as e:
            score = -0.5
            report.append(f"Exception in reward_func_process: {str(e)}")

        score = max(-1.0, min(1.0, score))
        report += [f"Final reward: {score}", "="*40]
        scores.append(score)

        if log_path:
            with open(log_path, 'a', encoding='utf-8') as sink:
                sink.write('\n'.join(report) + '\n')

    return scores

def prepare_model_for_training(model):
    """Put ``model`` into a trainable configuration and return it.

    Enables gradient checkpointing, turns off ``config.use_cache``, marks
    every parameter as requiring gradients and finally calls
    ``enable_input_require_grads()``.

    Args:
        model: Object providing ``gradient_checkpointing_enable()``,
            ``config``, ``parameters()`` and ``enable_input_require_grads()``;
            optionally a ``base_model`` with ``named_parameters()``.

    Returns:
        The same ``model`` object, modified in place.
    """
    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    for weight in model.parameters():
        weight.requires_grad = True

    # NOTE(review): every parameter was already made trainable above, so this
    # pass looks redundant for parameters shared with `model.parameters()` —
    # confirm whether the base weights were meant to stay frozen instead.
    if hasattr(model, 'base_model'):
        for name, weight in model.base_model.named_parameters():
            if 'lora' in name or 'adapter' in name:
                weight.requires_grad = True

    model.enable_input_require_grads()
    return model

def setup_logger(output_dir):
    """Create the run's log files under ``output_dir`` and configure logging.

    Creates ``output_dir`` and a ``reward_logs`` sub-directory, configures the
    root logger to write INFO messages to a timestamped log file and to the
    console, and creates a progress CSV containing only its header row. The
    reward log path is returned but the file itself is not created here.

    Args:
        output_dir: Directory that receives all log artefacts.

    Returns:
        tuple: ``(root_logger, progress_csv_path, reward_log_path)``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # A single timestamp for every artefact, so all files belonging to one
    # run carry the same suffix even if a second boundary is crossed here.
    run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(output_dir, f'training_{run_stamp}.log')

    # NOTE: basicConfig does nothing if the root logger already has handlers.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            # UTF-8 to match the other files written by this script.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ]
    )

    progress_file = os.path.join(output_dir, f'training_progress_{run_stamp}.csv')
    with open(progress_file, 'w', encoding='utf-8') as f:
        f.write("epoch,step,loss,reward_outcome,reward_process,total_reward,learning_rate\n")

    reward_log_dir = os.path.join(output_dir, "reward_logs")
    os.makedirs(reward_log_dir, exist_ok=True)
    reward_log_file = os.path.join(reward_log_dir, f'reward_log_{run_stamp}.txt')

    return logging.getLogger(), progress_file, reward_log_file

def main():
    """Run GRPO fine-tuning of a LoRA adapter on the letter QA dataset.

    Steps: load the QA text file, load the base causal LM in float16 on
    GPU 0, attach LoRA adapters, configure generation and GRPO training,
    install a logging callback and the two reward functions, train, and save
    the result to ``output_dir``. Returns early (after printing a message)
    when the dataset cannot be loaded.

    Raises:
        ValueError: If the dataset lacks the ``prompt``/``completion`` columns
            or model parameters are still on the meta device after loading.
        Exception: Any error raised by ``trainer.train()`` is logged with its
            traceback and re-raised.
    """
    qa_file = "/grpo_meterial/letter/anti_mixed_letter_data_100*4.txt"
    model_path = "/models/DeepSeek-R1-Distill-Llama-8B"
    output_dir = "/models/TP-ds-llama-8B-400-letter"

    dataset = load_qa_dataset(qa_file)
    if dataset is None:
        print("Failed to load dataset")
        return

    # Validate the schema before the (expensive) model load so a bad dataset
    # fails fast.
    if not all(key in dataset.features for key in ['prompt', 'completion']):
        raise ValueError("Dataset missing required fields")

    print(f"Dataset size: {len(dataset)}")

    print("Initial GPU memory status:")
    print_gpu_memory()
    clear_gpu_memory()

    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = True

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True
    )

    model = model.to("cuda")

    for param in model.parameters():
        if param.device == torch.device("meta"):
            raise ValueError("Model parameters are still on meta device, please check the loading process.")

    print("\nGPU memory status after model loading:")
    print_gpu_memory()

    model.config.use_cache = False
    model.config.pad_token_id = model.config.eos_token_id

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj"
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        inference_mode=False
    )

    model = get_peft_model(model, lora_config)
    model.train()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    generation_config = GenerationConfig(
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=model.config.pad_token_id,
        eos_token_id=model.config.eos_token_id,
        num_return_sequences=2,
        repetition_penalty=1.3,
        use_cache=False,
        return_dict_in_generate=True,
        output_scores=True
    )

    model.generation_config = generation_config

    # `warmup_ratio` is intentionally not set: when `warmup_steps` is given
    # it overrides the ratio, so specifying both was dead configuration.
    training_args = GRPOConfig(
        per_device_train_batch_size=2,
        num_generations=2,
        gradient_checkpointing=True,
        warmup_steps=100,
        weight_decay=0.01,
        remove_unused_columns=False,
        push_to_hub=False,
        torch_compile=False,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=5,
        report_to=[],
        output_dir=output_dir,
        logging_steps=10,
        save_steps=500,
        fp16=True,
        max_grad_norm=0.3,
        disable_tqdm=False,
        optim="adamw_torch",
        lr_scheduler_type="cosine_with_restarts"
    )

    logger, progress_file, reward_log_file = setup_logger(output_dir)
    logger.info("Starting training process")
    logger.info(f"Dataset size: {len(dataset)}")

    logger.info("Initial GPU memory status:")
    logger.info(f"GPU Memory Usage: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    logger.info(f"GPU Memory Cache: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

    class LoggingCallback(TrainerCallback):
        """Append each trainer log record to the progress CSV and the logger."""

        def __init__(self, logger, progress_file):
            self.logger = logger
            self.progress_file = progress_file

        @staticmethod
        def _reward_metric(logs, func_name):
            # Depending on the TRL version, per-function rewards are logged
            # as "rewards/<name>/mean" or as "rewards/<name>".
            return logs.get(f'rewards/{func_name}/mean',
                            logs.get(f'rewards/{func_name}', 0))

        def on_init_end(self, args, state, control, **kwargs):
            self.logger.info("Training initialization completed")
            return control

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs is None:
                return

            # Use the trainer's own step counter; counting on_log calls would
            # be off by a factor of `logging_steps`.
            step = state.global_step
            outcome_reward = self._reward_metric(logs, 'reward_func_outcome')
            process_reward = self._reward_metric(logs, 'reward_func_process')

            with open(self.progress_file, 'a', encoding='utf-8') as f:
                f.write(f"{logs.get('epoch', 0)},"
                        f"{step},"
                        f"{logs.get('loss', 0)},"
                        f"{outcome_reward},"
                        f"{process_reward},"
                        f"{logs.get('reward', 0)},"
                        f"{logs.get('learning_rate', 0)}\n")

            if step % 100 == 0:
                self.logger.info(
                    f"Step {step} | "
                    f"Loss: {logs.get('loss', 0):.4f} | "
                    f"Outcome Reward: {outcome_reward:.4f} | "
                    f"Process Reward: {process_reward:.4f} | "
                    f"Total Reward: {logs.get('reward', 0):.4f} | "
                    f"LR: {logs.get('learning_rate', 0):.6f}"
                )

    callbacks = [LoggingCallback(logger, progress_file)]

    def reward_func_outcome_with_log(completions, **kwargs):
        return reward_func_outcome(completions, reward_log_path=reward_log_file, **kwargs)

    def reward_func_process_with_log(prompts, completions, **kwargs):
        return reward_func_process(prompts, completions, reward_log_path=reward_log_file, **kwargs)

    # The trainer names its per-function reward metrics after `__name__`.
    # Keep the original function names so the callback's lookups match.
    reward_func_outcome_with_log.__name__ = reward_func_outcome.__name__
    reward_func_process_with_log.__name__ = reward_func_process.__name__

    # The tokenizer must be handed to the trainer itself; an attribute set on
    # the config object is not read by it.
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[reward_func_outcome_with_log, reward_func_process_with_log],
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        callbacks=callbacks
    )

    if hasattr(model, "decoder"):
        model.decoder = None
    if hasattr(model, "encoder"):
        model.encoder = None

    clear_gpu_memory()

    trainer.model.generation_config = generation_config
    trainer.model.config.use_cache = False
    # Route generation through generate_wrapper (defined at module level).
    trainer.model.generate = generate_wrapper.__get__(trainer.model)

    logger.info("Checking model configuration...")
    logger.info(f"use_cache: {trainer.model.config.use_cache}")
    logger.info(f"device: {next(trainer.model.parameters()).device}")
    logger.info(f"Training data size: {len(dataset)}")
    logger.info(f"batch_size: {training_args.per_device_train_batch_size}")
    logger.info(f"num_generations: {training_args.num_generations}")

    trainer.model.train()

    trainable = any(p.requires_grad for p in trainer.model.parameters())
    logger.info(f"Model is trainable: {trainable}")

    try:
        trainer.train()
    except Exception:
        # logger.exception records the traceback; bare raise keeps it intact.
        logger.exception("Error during training")
        raise
    logger.info("Training completed")

    # Persist the final weights; periodic checkpoints alone (every
    # `save_steps`) do not guarantee the last state is on disk.
    trainer.save_model(output_dir)
    logger.info(f"Model saved to {output_dir}")

def generate_wrapper(self, input_ids=None, attention_mask=None, **kwargs):
    """Generate with ``self.base_model`` in eval mode, then restore train mode.

    Written as an unbound method: ``self`` is the model the function is bound
    to. Inputs are moved to ``self.device`` before generation.

    Args:
        input_ids: Optional tensor of input ids; moved to ``self.device``.
        attention_mask: Optional attention mask; moved to ``self.device``.
        **kwargs: Forwarded unchanged to ``self.base_model.generate``.

    Returns:
        Whatever ``self.base_model.generate`` returns.

    Raises:
        Exception: Any error from moving the inputs or from generation is
            printed and re-raised.
    """
    was_training = self.training
    try:
        if input_ids is not None:
            input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        self.eval()

        return self.base_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        raise
    finally:
        # Restore training mode even when generation fails, so a caught
        # error cannot leave the model stuck in eval mode.
        if was_training:
            self.train()


# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()